import os
# Set global environment variables
os.environ["TORCH_HOME"] = "/GGboy/cache"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
import torch
import numpy as np
import pandas as pd
import random
import cv2
from tqdm import tqdm
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math

# For SAM (Segment Anything Model)
from segment_anything import sam_model_registry, SamPredictor
import urllib.request # For downloading SAM weights

# For Grad-CAM
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Import for new transformation logic
import torchvision.transforms.functional as TF




class CWT_Transform:
    """Block-wise scale/rotate augmentation that yields several variants per image.

    For every input image the transform is repeated ``num_scale`` times. Each
    repetition splits the image into a ``num_block`` x ``num_block`` grid of
    randomly sized blocks, rescales every block (down, then up, by a randomly
    drawn factor), rotates a random non-empty subset of the blocks, crops each
    block back to its original size and pastes the blocks back in place.

    Randomness comes from both ``numpy.random`` (block sizes) and the stdlib
    ``random`` module (scale factors, rotation choice/angle, crop offsets).
    """

    def __init__(self, num_scale=20, num_block=2, rotation_probability=0.5, max_angle=20, range_max=1.3):
        """Store the augmentation parameters.

        Args:
            num_scale: Number of transformed copies generated per input image.
            num_block: Number of blocks along each image dimension.
            rotation_probability: Stored on the instance only; no method of
                this class reads it.
            max_angle: Rotation angles are drawn uniformly from
                ``[-max_angle, max_angle]`` degrees.
            range_max: Upper bound passed to ``generate_scale_factors``.
        """
        self.num_scale = num_scale
        self.num_block = num_block
        self.max_angle = max_angle
        self.range_max = range_max
        self.rotation_probability = rotation_probability

    def get_length(self, length, num_blocks):
        """Randomly split ``length`` into ``num_blocks`` integer parts.

        When ``length >= 5 * num_blocks`` every part is at least 5 and the
        parts sum to ``length``. For shorter dimensions that guarantee is
        impossible, so the length is split as evenly as possible instead
        (parts may then be smaller than 5, or 0 when ``length < num_blocks``).

        Args:
            length: Total size of the dimension being split.
            num_blocks: Number of parts to produce (must be >= 1).

        Returns:
            Tuple of ``num_blocks`` Python ints that sum to ``length``.

        Raises:
            ValueError: If ``num_blocks`` is smaller than 1.
        """
        min_size = 5
        if num_blocks < 1:
            raise ValueError("num_blocks must be at least 1")

        if length < num_blocks * min_size:
            # Too short to honour the minimum size: fall back to a near-even
            # split rather than emitting negative or inconsistent lengths.
            base, extra = divmod(int(length), num_blocks)
            return tuple(base + (1 if k < extra else 0) for k in range(num_blocks))

        rand = np.random.uniform(size=num_blocks)
        sizes = np.round(rand / rand.sum() * (length - num_blocks * min_size)).astype(np.int64)
        sizes[sizes < min_size] = min_size  # Ensure minimum block size
        # The remainder is never negative here, so the last block only grows.
        sizes[-1] += length - sizes.sum()
        # Plain ints keep later slicing and interpolate(size=...) calls simple.
        return tuple(int(v) for v in sizes)

    def split_image_into_blocks(self, img_tensor, num_blocks=(2, 2)):
        """Split an image tensor into a grid of randomly sized blocks.

        Args:
            img_tensor: Tensor of shape (C, H, W) or (B, C, H, W); a 3-D input
                gets a leading batch dimension added.
            num_blocks: ``(rows, cols)`` of the grid.

        Returns:
            ``(blocks, row_lengths, col_lengths)`` where ``blocks`` is a
            row-major list of tensor views (``None`` for a zero-sized cell)
            and the two length tuples come from ``get_length``.
        """
        if img_tensor.ndim == 3:
            img_tensor = img_tensor.unsqueeze(0)

        _, _, H, W = img_tensor.shape
        row_lengths = self.get_length(H, num_blocks[0])
        col_lengths = self.get_length(W, num_blocks[1])

        blocks = []
        start_h = 0
        for row in row_lengths:
            start_w = 0
            for col in col_lengths:
                if row > 0 and col > 0:
                    block = img_tensor[:, :, start_h:start_h + row, start_w:start_w + col]
                else:
                    block = None  # zero-sized cell: nothing to transform
                blocks.append(block)
                start_w += col
            start_h += row

        return blocks, row_lengths, col_lengths

    def random_crop_edges(self, blocks, H, W):
        """Crop a random (H, W) window out of ``blocks``.

        Args:
            blocks: Tensor of shape (C, h, w) or (B, C, h, w).
            H: Target height.
            W: Target width.

        Returns:
            Tensor of shape (B, C, H, W). If the input is smaller than the
            target along some axis, the crop is bilinearly resized to (H, W).
        """
        if blocks.ndim == 3:
            blocks = blocks.unsqueeze(0)

        scaled_H, scaled_W = blocks.shape[2:]

        # Random top-left corner; pinned to 0 when there is no slack.
        top = 0 if scaled_H == H else random.randint(0, max(0, scaled_H - H))
        left = 0 if scaled_W == W else random.randint(0, max(0, scaled_W - W))

        cropped_blocks = blocks[:, :, top:top + H, left:left + W]

        # Fallback when the source was smaller than the requested window.
        if cropped_blocks.shape[2:] != (H, W):
            cropped_blocks = torch.nn.functional.interpolate(cropped_blocks, size=(H, W), mode='bilinear', align_corners=False)

        return cropped_blocks

    def generate_scale_factors(self, num_blocks, range_min=1.0, range_max=1.3, num_choices=41):
        """Draw ``num_blocks`` scale factors from an evenly spaced grid.

        The grid has ``num_choices`` values between ``range_min`` and
        ``range_max`` (swapped if given in the wrong order), each rounded to
        two decimals. The grid is shuffled and its first ``num_blocks``
        entries are used; if more factors are requested than the grid holds,
        the remainder is drawn from the grid with replacement so the caller
        always receives exactly ``num_blocks`` values.

        Returns:
            List of ``num_blocks`` floats.
        """
        if range_min > range_max:
            range_min, range_max = range_max, range_min

        if num_choices < 2:
            # A single-point grid; also avoids dividing by (num_choices - 1) == 0.
            scale_choices = [round(range_min, 2)]
        else:
            scale_choices = [round(range_min + i * (range_max - range_min) / (num_choices - 1), 2) for i in range(num_choices)]
        random.shuffle(scale_choices)

        factors = scale_choices[:num_blocks]
        missing = num_blocks - len(factors)
        if missing > 0:
            factors.extend(random.choices(scale_choices, k=missing))
        return factors

    def transform_block(self, block_batch, scale_factor, i, idx_to_rotate, H_orig, W_orig):
        """Rescale, optionally rotate, and crop one block back to its own size.

        Args:
            block_batch: Block tensor of shape (B, C, h, w), or ``None``.
            scale_factor: The block is first resized by ``1 / scale_factor``
                and then to ``scale_factor`` times its original size.
            i: Index of this block within the grid.
            idx_to_rotate: Collection of block indices that get rotated.
            H_orig: Height used only for the placeholder returned when
                ``block_batch`` is ``None``.
            W_orig: Width used only for that placeholder.

        Returns:
            Tensor with the same spatial size as ``block_batch``. A ``None``
            input yields a zero tensor of shape (1, 3, H_orig, W_orig) on the
            default device; blocks thinner than 2 pixels are returned as-is.
        """
        if block_batch is None:
            # There is no tensor to take a device from, so use the default one.
            return torch.zeros((1, 3, H_orig, W_orig))
        if block_batch.shape[2] < 2 or block_batch.shape[3] < 2:
            return block_batch

        H_block, W_block = block_batch.shape[2:]

        # First, downscale
        scaled_H_down = max(2, int(H_block / scale_factor))
        scaled_W_down = max(2, int(W_block / scale_factor))
        resized_blocks = torch.nn.functional.interpolate(
            block_batch, size=(scaled_H_down, scaled_W_down), mode='bilinear', align_corners=False
        )

        # Then upscale past the original size; the crop below brings it back.
        scaled_H_up = max(2, int(H_block * scale_factor))
        scaled_W_up = max(2, int(W_block * scale_factor))
        scaled_blocks = torch.nn.functional.interpolate(
            resized_blocks, size=(scaled_H_up, scaled_W_up), mode='bilinear', align_corners=False
        )

        # Apply rotation if this block is selected
        if i in idx_to_rotate:
            angle = random.uniform(-self.max_angle, self.max_angle)
            scaled_blocks = TF.rotate(
                scaled_blocks,
                angle=angle,
                expand=False,
                center=(scaled_W_up / 2, scaled_H_up / 2),
                fill=0  # Fill rotated empty areas with 0 (black)
            )

        # Randomly crop back to this block's own (pre-scaling) size.
        return self.random_crop_edges(scaled_blocks, H_block, W_block)

    def __call__(self, x, **kwargs):
        """Apply the block-wise transformation to an image batch.

        Args:
            x: Tensor of shape (B, C, H, W).
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape (B * num_scale, C, H, W); the ``num_scale``
            variants of each input image are stored contiguously. An empty
            batch (or ``num_scale == 0``) yields an empty (0, C, H, W) tensor.
        """
        B, C, H, W = x.shape
        transformed_images_batch = []

        for img_idx in range(B):
            single_image_tensor = x[img_idx:img_idx + 1]  # (1, C, H, W)

            for _ in range(self.num_scale):
                blocks, row_lengths, col_lengths = self.split_image_into_blocks(
                    single_image_tensor, num_blocks=(self.num_block, self.num_block)
                )

                scale_factors = self.generate_scale_factors(len(blocks), range_max=self.range_max)

                # Pick a random non-empty subset of blocks to rotate.
                if len(blocks) >= 1:
                    num_to_rotate = random.randint(1, len(blocks))
                    idx_to_rotate = random.sample(range(len(blocks)), num_to_rotate)
                else:
                    idx_to_rotate = []

                # Transform each block; None entries keep the list index-aligned.
                transformed_blocks_list = []
                for i, block in enumerate(blocks):
                    if block is None:
                        transformed_blocks_list.append(None)
                        continue
                    H_orig_block, W_orig_block = block.shape[2:]
                    transformed_blocks_list.append(
                        self.transform_block(block, scale_factors[i], i, idx_to_rotate, H_orig_block, W_orig_block)
                    )

                # Paste the transformed blocks back at their grid positions.
                transformed_image_reconstructed = torch.zeros_like(single_image_tensor)
                start_h, block_idx = 0, 0
                for row_len in row_lengths:
                    start_w = 0
                    for col_len in col_lengths:
                        current_block = transformed_blocks_list[block_idx]
                        if current_block is not None:
                            # Guard against a size mismatch before assignment.
                            if current_block.shape[2] != row_len or current_block.shape[3] != col_len:
                                current_block = torch.nn.functional.interpolate(
                                    current_block, size=(row_len, col_len), mode='bilinear', align_corners=False
                                )
                            transformed_image_reconstructed[:, :, start_h:start_h + row_len, start_w:start_w + col_len] = current_block
                        start_w += col_len
                        block_idx += 1
                    start_h += row_len

                transformed_images_batch.append(transformed_image_reconstructed)

        if not transformed_images_batch:
            # torch.cat rejects an empty list; return an empty batch instead.
            return x.new_zeros((0, C, H, W))

        # Concatenate all transformed images into a single batch
        return torch.cat(transformed_images_batch, dim=0)




# Input and output paths
input_folder = '/new_data/images'
label_csv_path = '/new_data/filtered_labels.csv'

# Only the first MAX_LABELLED_IMAGES rows of the label CSV are evaluated.
MAX_LABELLED_IMAGES = 400

# Read labels: map each filename to its label (later duplicates win).
df_labels = pd.read_csv(label_csv_path)
_label_rows = df_labels.head(MAX_LABELLED_IMAGES)
labels_dict = dict(zip(_label_rows['filename'], _label_rows['label']))

# NOTE: the device, preprocessing transforms, classifier, SAM predictor,
# NUM_SCALE / NUM_BLOCK and cwt_transformer are configured once further down
# in this file, where the SAM load is wrapped in error handling. An earlier
# copy of that setup lived here; it loaded the SAM checkpoint without any
# error handling and every name it defined was overwritten by the later
# section, so it has been removed to avoid loading the checkpoint twice.


# Load image function (unchanged from previous version)
def load_image(image_path):
    """Load an image and prepare the three representations the script uses.

    NOTE: ``load_image`` is defined a second time later in this file; that
    later definition rebinds the name and is the one the main loop calls.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A 3-tuple ``(model_tensor, rgb_array, cam_tensor)``:
        ``transform_for_model(img)`` with a leading batch dimension, the RGB
        image as a NumPy array at its original size, and
        ``transform_for_cam_display(img)`` with a leading batch dimension.
    """
    # Image.open keeps the file handle open; the context manager closes it
    # once convert() has produced an independent in-memory RGB copy.
    with Image.open(image_path) as source:
        img = source.convert('RGB')
    # For model inference, use transform_for_model (includes normalization)
    img_tensor_for_model = transform_for_model(img).unsqueeze(0)
    # For SAM and CAM visualization, use transform_for_cam_display (only resize and to tensor, no normalization)
    img_tensor_for_cam = transform_for_cam_display(img).unsqueeze(0)
    return img_tensor_for_model, np.array(img), img_tensor_for_cam


# --- Image Preprocessing ---
# Classifier input: resize to 224x224, convert to a tensor and normalize
# with the given per-channel mean / std.
transform_for_model = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Display variant: same resize and tensor conversion, but no normalization.
transform_for_cam_display = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# --- Model Settings (ResNet-101) ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Alternative backbones, kept for quick switching:
# model = models.resnet18(pretrained=True).to(device)
# model = models.resnext50_32x4d(pretrained=True).to(device)
model = models.resnet101(pretrained=True).to(device)
model.eval()

# --- SAM Model Setup ---
# Local SAM checkpoint. If the path is wrong or loading fails, the script
# prints the reason and exits with a non-zero status.
SAM_CHECKPOINT_PATH = "/cache/hub/checkpoints/sam_vit_b_01ec64.pth"
SAM_MODEL_TYPE = "vit_b"

print(f"Loading SAM model from {SAM_CHECKPOINT_PATH}...")
try:
    sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT_PATH).to(device)
    sam_predictor = SamPredictor(sam)
    print("SAM model loaded.")
except FileNotFoundError:
    print(f"Error: SAM checkpoint file not found at {SAM_CHECKPOINT_PATH}.")
    print("Please ensure the path is correct.")
    # SystemExit(1) rather than exit(): exit() is an interactive helper from
    # the site module and would report success (status 0) here.
    raise SystemExit(1)
except Exception as e:
    print(f"Error loading SAM model: {e}")
    raise SystemExit(1)


# --- Transformation Parameters ---
NUM_SCALE = 20  # Number of transformed images to generate per input
NUM_BLOCK = 2   # Number of blocks per dimension (e.g., 2x2 grid)

# Initialize the CWT_Transform class with defined parameters
cwt_transformer = CWT_Transform(num_scale=NUM_SCALE, num_block=NUM_BLOCK, max_angle=20, range_max=1.3)

# --- SAM Mask Coverage Threshold ---
# Minimum fraction (10%) of the 224x224 image the SAM mask should cover.
# The skip check that uses it is currently commented out in the main loop;
# the value is still printed in the summary.
SAM_COVERAGE_THRESHOLD = 0.1

# --- Union IOU Filtering Threshold ---
# Minimum Union IOU for a sample to be counted. The filter that uses it is
# currently commented out in the main loop; the value is still printed in the
# summary.
IOU_FILTER_THRESHOLD = 0.2


# --- Image Loading Function ---
def load_image(image_path):
    """Load an image and prepare the three representations the script uses.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A 3-tuple ``(model_tensor, rgb_array, cam_tensor)``:
        ``transform_for_model(img)`` with a leading batch dimension, the RGB
        image as a NumPy array at its original size, and
        ``transform_for_cam_display(img)`` with a leading batch dimension.
    """
    # Image.open keeps the file handle open; the context manager closes it
    # once convert() has produced an independent in-memory RGB copy.
    with Image.open(image_path) as source:
        img = source.convert('RGB')
    img_tensor_for_model = transform_for_model(img).unsqueeze(0)
    img_tensor_for_cam = transform_for_cam_display(img).unsqueeze(0)
    return img_tensor_for_model, np.array(img), img_tensor_for_cam

# --- Evaluation Metrics ---
total_union_iou_sum = 0
image_processed_count = 0

# NOTE: the low-coverage and low-Union-IOU filters below are commented out,
# so only skipped_due_to_sam_failure can become non-zero at the moment.
skipped_due_to_low_coverage = 0
skipped_due_to_sam_failure = 0
skipped_due_to_low_union_iou = 0

filtered_labels = []

# Grad-CAM setup does not depend on the image, so build it once instead of
# once per transformed image. The target is the last block of the model's
# final residual stage (layer4).
target_layers = [model.layer4[-1]]
cam = GradCAM(model=model, target_layers=target_layers)

# A heatmap pixel counts as "active" when it exceeds this fraction of the
# heatmap's own maximum.
CAM_BINARIZE_RATIO = 0.6

# --- Main Processing Loop ---
for filename in tqdm(os.listdir(input_folder), desc="Processing Images"):
    # Only files that have a label entry are evaluated.
    if filename not in labels_dict:
        continue

    ground_truth_label = labels_dict[filename]
    image_path = os.path.join(input_folder, filename)

    # Load and preprocess image
    input_image_tensor_for_model, original_image_np, input_image_tensor_for_cam = load_image(image_path)
    input_image_tensor_for_model = input_image_tensor_for_model.to(device)

    # --- 1. SAM Segmentation on Original Image ---
    # A single positive point prompt at the image centre.
    sam_predictor.set_image(original_image_np)
    H, W, _ = original_image_np.shape
    input_point = np.array([[W // 2, H // 2]])
    input_label = np.array([1])

    try:
        masks, scores, logits = sam_predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        # Keep the highest-scoring candidate mask and bring it to the
        # 224x224 resolution the classifier and CAM work at.
        best_mask_idx = np.argmax(scores)
        sam_mask = masks[best_mask_idx].astype(np.uint8)
        sam_mask_resized = cv2.resize(sam_mask, (224, 224), interpolation=cv2.INTER_NEAREST)

        # Check SAM Mask Coverage (filter currently disabled)
        sam_mask_coverage = np.sum(sam_mask_resized) / (224 * 224)
        # if sam_mask_coverage < SAM_COVERAGE_THRESHOLD:
        #     # print(f"Skipping {filename}: SAM mask coverage ({sam_mask_coverage:.2f}) is below threshold ({SAM_COVERAGE_THRESHOLD:.2f}).")
        #     skipped_due_to_low_coverage += 1
        #     continue # Skip to the next image

        sam_mask_tensor = torch.from_numpy(sam_mask_resized).float().unsqueeze(0).unsqueeze(0).to(device)

    except Exception as e:
        # print(f"SAM segmentation failed for {filename}: {e}. Skipping this image.")
        skipped_due_to_sam_failure += 1
        continue  # Skip to next image if SAM fails

    # --- Generate the transformed images for this input ---
    batch_transformed_images = cwt_transformer(input_image_tensor_for_model)

    # --- 2. Grad-CAM on Generated Transformed Images & Calculate Union IOU ---
    # zeros_like already allocates on sam_mask_tensor's device.
    union_heatmap_mask_tensor = torch.zeros_like(sam_mask_tensor, dtype=torch.bool)

    successful_cams_for_union = 0

    # Iterate over what the transformer actually returned rather than
    # assuming the batch holds exactly NUM_SCALE images.
    for i in range(batch_transformed_images.shape[0]):
        single_transformed_image_tensor = batch_transformed_images[i:i + 1].to(device)

        with torch.no_grad():
            outputs = model(single_transformed_image_tensor)
            # Grad-CAM targets the predicted class; softmax is monotonic,
            # so the argmax of the raw outputs selects the same class.
            predicted_class = torch.argmax(outputs, dim=1).item()

        targets = [ClassifierOutputTarget(predicted_class)]

        grayscale_cam = cam(input_tensor=single_transformed_image_tensor, targets=targets)
        grayscale_cam = grayscale_cam[0, :]  # Remove batch dimension (H, W)

        heatmap_resized = cv2.resize(grayscale_cam, (224, 224))
        # Binarize heatmap relative to its own maximum.
        heatmap_mask_np = (heatmap_resized > np.max(heatmap_resized) * CAM_BINARIZE_RATIO).astype(np.uint8)

        # Convert to a boolean (1, 1, H, W) tensor for the OR accumulation.
        current_heatmap_mask_tensor = torch.from_numpy(heatmap_mask_np).bool().unsqueeze(0).unsqueeze(0).to(device)

        # Accumulate the union of all per-variant heatmap masks.
        union_heatmap_mask_tensor = union_heatmap_mask_tensor | current_heatmap_mask_tensor
        successful_cams_for_union += 1

    # After processing all transformed images for this filename:
    if successful_cams_for_union > 0:
        sam_mask_bool = sam_mask_tensor.bool()

        intersection = (sam_mask_bool & union_heatmap_mask_tensor).sum().item()
        union = (sam_mask_bool | union_heatmap_mask_tensor).sum().item()

        union_iou = intersection / (union + 1e-6)  # epsilon avoids 0/0 for two empty masks

        # --- UNION IOU FILTERING LOGIC (currently disabled) ---
        # if union_iou < IOU_FILTER_THRESHOLD:
        #     # print(f"Skipping {filename}: Union IOU ({union_iou:.4f}) is below threshold ({IOU_FILTER_THRESHOLD:.2f}).")
        #     skipped_due_to_low_union_iou += 1
        #     # Do NOT add to total_union_iou_sum or image_processed_count if skipped
        # else:
        #     # ONLY if IOU is above threshold, count it towards the average
        #     total_union_iou_sum += union_iou
        #     image_processed_count += 1
        #     # print(f"Union IOU for {filename}: {union_iou:.4f}")
        #     filtered_labels.append({'filename': filename, 'label': ground_truth_label})
        total_union_iou_sum += union_iou
        image_processed_count += 1
        # print(f"Union IOU for {filename}: {union_iou:.4f}")
        filtered_labels.append({'filename': filename, 'label': ground_truth_label})

# Final overall average Union IOU and summary
print("\n--- Summary ---")
if image_processed_count > 0:
    overall_avg_union_iou = total_union_iou_sum / image_processed_count
    print(f"Overall Average Union IOU across VALID processed images: {overall_avg_union_iou:.4f} ({image_processed_count} images counted)")
else:
    print("No images were successfully processed for Union IOU calculation after filtering.")

print(f"Images skipped due to low SAM coverage (< {SAM_COVERAGE_THRESHOLD:.2f}): {skipped_due_to_low_coverage}")
print(f"Images skipped due to SAM segmentation failure: {skipped_due_to_sam_failure}")
print(f"Images skipped due to low Union IOU (< {IOU_FILTER_THRESHOLD:.2f}): {skipped_due_to_low_union_iou}")

# if filtered_labels:
#     df_filtered_labels = pd.DataFrame(filtered_labels)
#     output_csv_path = os.path.join(os.path.dirname(label_csv_path), 'filtered_labels.csv') 
#     df_filtered_labels.to_csv(output_csv_path, index=False)
#     print(f"\nSuccessfully saved {len(filtered_labels)} filtered image labels to {output_csv_path}")
# else:
#     print("\nNo images passed all filtering conditions. No new labels.csv file was created.")